/**
 * 
 */
package edu.berkeley.nlp.discPCFG;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;

import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.GrammarMerger;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction.Counts;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.ScalingTools;

/**
 * Merges grammar sub-states according to a conditional merging criterion.
 *
 * <p>The training trees are split into blocks that are distributed round-robin
 * over {@code nProcesses} {@link Merger} tasks. Each task accumulates a table of
 * merge deltas for its share of the trees; {@link #mergeGrammarAndLexicon()}
 * sums these tables, determines the pairs to merge, performs the merges and
 * saves the resulting grammar.
 *
 * @author petrov
 */
public class ConditionalMerger {
	int nProcesses;
	String consBaseName;
	Grammar grammar;
	Lexicon lexicon;
	double mergingPercentage;
	String outFileName;

	StateSetTreeList[] trainingTrees;
	ExecutorService pool;
	Merger[] tasks;
	double[][] mergeWeights;

	/**
	 * One worker: computes the merge deltas for its share of the training trees.
	 * Instances are submitted to the enclosing class's thread pool and return
	 * their delta table from {@link #call()}.
	 */
	class Merger implements Callable<double[][][]> {
		ArrayParser gParser;
		ConstrainedTwoChartsParser eParser;
		StateSetTreeList myTrees;
		String consName;
		int myID;
		int nCounts;
		boolean[][][][][] myConstraints;
		int unparsableTrees, incorrectLLTrees;
		double[][] mergeWeights;

		/**
		 * @param myT the trees this worker is responsible for
		 * @param consN base name of the constraint files, or {@code null} to parse without constraints
		 * @param i index of this worker; also selects which constraint blocks it loads
		 * @param gr grammar handed to both parsers
		 * @param lex lexicon handed to both parsers
		 * @param mergeWeights merge weights passed on when tallying the conditional loss
		 */
		Merger(StateSetTreeList myT, String consN, int i, Grammar gr, Lexicon lex, double[][] mergeWeights){
			this.consName = consN;
			this.myTrees = myT;
			this.myID = i;
			this.mergeWeights = mergeWeights;
			gParser = new ArrayParser(gr, lex);
			eParser = new ConstrainedTwoChartsParser(gr, lex, null);
		}

		/**
		 * Loads one constraint entry per tree from the block files
		 * {@code consName-<blockNumber>.data}, where this worker reads block numbers
		 * {@code myID, myID+nProcesses, myID+2*nProcesses, ...}.
		 *
		 * <p>An entry whose length differs from the sentence length is reported and
		 * stored as {@code null}. When {@code consName} is {@code null} all entries
		 * stay {@code null}.
		 *
		 * @throws IllegalStateException if a required block file cannot be loaded or is empty
		 */
		private void loadConstraints(){
			// Build into a local so the field is only published once loading succeeded.
			boolean[][][][][] loaded = new boolean[myTrees.size()][][][][];
			if (consName==null) {
				myConstraints = loaded;
				return;
			}
			boolean[][][][][] curBlock = null;
			int block = 0;
			int i = 0;
			for (int tree=0; tree<myTrees.size(); tree++){
				if (curBlock == null || i >= curBlock.length){
					int blockNumber = ((block*nProcesses)+myID);
					String fileName = consName+"-"+blockNumber+".data";
					curBlock = loadData(fileName);
					if (curBlock == null || curBlock.length == 0){
						// loadData already reported the I/O problem; fail with the file name
						// instead of a bare NullPointerException/ArrayIndexOutOfBoundsException.
						throw new IllegalStateException("Could not load constraints from "+fileName);
					}
					block++;
					i = 0;
					System.out.print(".");
				}
				eParser.projectConstraints(curBlock[i]);
				loaded[tree] = curBlock[i];
				i++;
				if (loaded[tree].length!=myTrees.get(tree).getYield().size()){
					System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
					System.out.println("Sentence length and constraints length do not match!");
					loaded[tree] = null;
				}
			}
			myConstraints = loaded;
		}

		/**
		 * Runs both parsers over every tree of this worker and tallies the
		 * conditional loss into a fresh delta table.
		 *
		 * <p>A tree whose constraint entry is {@code null} (no constraint files, or
		 * the entry was discarded by {@link #loadConstraints()}) is parsed with
		 * {@code null} constraints.
		 *
		 * @return the delta table, dimensioned
		 *         {@code [grammar.numStates][mergeWeights[0].length][mergeWeights[0].length]}
		 */
		public double[][][] call() {
			if (myConstraints==null) loadConstraints();
			double[][][] deltas = new double[grammar.numStates][mergeWeights[0].length][mergeWeights[0].length];
			final boolean noSmoothing = true, debugOutput = false;
			// Never advanced in this method; retained so the diagnostic format below is unchanged.
			final int block = 0;
			int i = -1;
			for (Tree<StateSet> stateSetTree : myTrees) {
				i++;
				gParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput);

				// parse the sentence
				List<StateSet> leaves = stateSetTree.getYield();
				boolean[][][][] cons = null;
				if (consName!=null){
					cons = myConstraints[i];
					// cons may be null when loadConstraints() rejected this entry;
					// the original dereferenced it unconditionally and crashed.
					if (cons != null && cons.length != leaves.size()){
						System.out.println("My ID: "+myID+", block: "+block+", sentence: "+i);
						System.out.println("Sentence length ("+leaves.size()+") and constraints length ("+cons.length+") do not match!");
						System.exit(-1);
					}
				}
				eParser.doConstrainedInsideOutsideScores(leaves,cons,noSmoothing,stateSetTree,null,false);

				eParser.tallyConditionalLoss(stateSetTree, deltas, mergeWeights);

				if (i%100==0) System.out.print(".");
			}

			System.out.print(" "+myID+" ");
			return deltas;
		}

		/**
		 * Reads a gzip-compressed, Java-serialized {@code boolean[][][][][]} from a file.
		 *
		 * <p>NOTE(review): this uses Java native deserialization, so it should only be
		 * pointed at constraint files from a trusted source.
		 *
		 * @param fileName path of the file to read
		 * @return the deserialized array, or {@code null} if the file could not be
		 *         read or did not contain a known class (the problem is printed)
		 */
		public boolean[][][][][] loadData(String fileName) {
			FileInputStream fis = null;
			ObjectInputStream in = null;
			try {
				fis = new FileInputStream(fileName); // Load from file
				in = new ObjectInputStream(new GZIPInputStream(fis)); // Compressed, load objects
				return (boolean[][][][][])in.readObject();
			} catch (IOException e) {
				System.out.println("IOException\n"+e);
				return null;
			} catch (ClassNotFoundException e) {
				System.out.println("Class not found!");
				return null;
			} finally {
				// Always release the file handle, also when reading failed half-way.
				try {
					if (in != null) in.close();
					else if (fis != null) fis.close();
				} catch (IOException ignored) {
					// Nothing useful can be done if closing fails; the read result stands.
				}
			}
		}

	}

	/**
	 * Prepares the merger: prints the statistics of the generative merging
	 * criterion on a copy of the grammar, splits the training trees into blocks
	 * and creates one {@link Merger} task per process.
	 *
	 * @param processes number of worker tasks/threads; must be at least 1
	 * @param consBaseName base name of the constraint block files
	 *        ({@code <base>-<n>.data}), or {@code null} to work without constraints
	 * @param trainTrees the training trees to distribute over the workers
	 * @param gr the grammar to merge
	 * @param lex the lexicon belonging to {@code gr}
	 * @param mergingPercentage passed to {@code GrammarMerger.determineMergePairs}
	 * @param outFileName base name of the output file; the merged grammar is
	 *        saved to {@code outFileName + "-merged"}
	 * @throws IllegalArgumentException if {@code processes} is less than 1
	 */
	public ConditionalMerger(int processes, String consBaseName, StateSetTreeList trainTrees,
				Grammar gr, Lexicon lex, double mergingPercentage, String outFileName) {
		if (processes < 1) {
			throw new IllegalArgumentException("processes must be at least 1, but was "+processes);
		}
		this.nProcesses = processes;
		this.consBaseName = consBaseName;
		this.grammar = gr;//.copyGrammar();
		this.lexicon = lex;//.copyLexicon();
		this.mergingPercentage = mergingPercentage;
		this.outFileName = outFileName;

		// Default block size: an even split. Clamp to 1 so that having fewer trees
		// than processes does not end in a modulo-by-zero below.
		int nTreesPerBlock = Math.max(1, trainTrees.size()/processes);
		if (consBaseName != null) {
			// If constraint files exist, the size of the first block dictates the block size.
			boolean[][][][][] tmp = edu.berkeley.nlp.PCFGLA.ParserConstrainer.loadData(consBaseName+"-0.data");
			if (tmp!=null && tmp.length>0) nTreesPerBlock = tmp.length;
		}

		// first compute the generative merging criterion
		mergeWeights = GrammarMerger.computeMergeWeights(grammar, lexicon,trainTrees);
		double[][][] deltas = GrammarMerger.computeDeltas(grammar, lexicon, mergeWeights, trainTrees);
		boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs(deltas,false,mergingPercentage,grammar);
		Grammar tmpGrammar = grammar.copyGrammar(true);
		Lexicon tmpLexicon = lexicon.copyLexicon();
		tmpGrammar = GrammarMerger.doTheMerges(tmpGrammar, tmpLexicon, mergeThesePairs, mergeWeights);
		System.out.println("Generative merging criterion gives:");
		GrammarMerger.printMergingStatistics(grammar, tmpGrammar);
		// NOTE(review): the weights are recomputed after the trial merge, presumably
		// because the calls above may modify them - confirm against GrammarMerger.
		mergeWeights = GrammarMerger.computeMergeWeights(grammar, lexicon,trainTrees);

		// split the trees into chunks
		trainingTrees = new StateSetTreeList[nProcesses];
		for (int i=0; i<nProcesses; i++){
			trainingTrees[i] = new StateSetTreeList();
		}
		int block = -1;
		int inBlock = 0;
		for (int i=0; i<trainTrees.size(); i++){
			if (i%nTreesPerBlock==0) {
				block++;
				System.out.println(inBlock);
				inBlock = 0;
			}
			// blocks are dealt out round-robin, matching the block numbers each Merger loads
			trainingTrees[block%nProcesses].add(trainTrees.get(i));
			inBlock++;
		}
		trainTrees = null;
		pool = Executors.newFixedThreadPool(nProcesses);//CachedThreadPool();

		tasks = new Merger[nProcesses];
		for (int i=0; i<nProcesses; i++){
			tasks[i] = new Merger(trainingTrees[i],consBaseName,i, grammar, lexicon, mergeWeights);
		}

	}

	/**
	 * Runs all {@link Merger} tasks, sums their delta tables, determines and
	 * performs the merges and saves the merged grammar to
	 * {@code outFileName + "-merged"}.
	 *
	 * <p>The thread pool is shut down when the tasks are finished (so its worker
	 * threads do not keep the JVM alive) and is recreated if this method is
	 * called again.
	 *
	 * @throws IllegalStateException if a task fails or the calling thread is
	 *         interrupted while waiting; the original exception is the cause
	 */
	public void mergeGrammarAndLexicon(){
		System.out.print("Task: ");
		if (pool == null || pool.isShutdown()) {
			pool = Executors.newFixedThreadPool(nProcesses);
		}

		double[][][] deltas = new double[grammar.numStates][mergeWeights[0].length][mergeWeights[0].length];
		try {
			List<Future<double[][][]>> submits = new ArrayList<Future<double[][][]>>(nProcesses);
			for (int i=0; i<nProcesses; i++){
				submits.add(pool.submit(tasks[i]));
			}
			// accumulate; Future.get() blocks until the task is done, so no polling is needed
			for (Future<double[][][]> submit : submits) {
				addInto(deltas, awaitCounts(submit));
			}
		} finally {
			pool.shutdown();
		}
		System.out.print(" done. ");
		System.out.println("Conditional merging criterion gives:");
		boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs(deltas,false,mergingPercentage,grammar);
		Grammar newGrammar = GrammarMerger.doTheMerges(grammar, lexicon, mergeThesePairs, mergeWeights);
		GrammarMerger.printMergingStatistics(grammar, newGrammar);

		ParserData pData = new ParserData(lexicon, newGrammar, null, Numberer.getNumberers(), newGrammar.numSubStates, 1, 0, Binarization.RIGHT);
		// Report the name that is actually written (the original message omitted the suffix).
		System.out.println("Saving grammar to "+outFileName+"-merged.");
		if (pData.Save(outFileName+"-merged")) System.out.println("Saving successful.");
		else System.out.println("Saving failed!");

	}

	/** Waits for one task and returns its delta table, converting failures into unchecked exceptions. */
	private static double[][][] awaitCounts(Future<double[][][]> pending) {
		try {
			return pending.get();
		} catch (ExecutionException e) {
			Throwable cause = (e.getCause() != null) ? e.getCause() : e;
			throw new IllegalStateException("A merging task failed", cause);
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt(); // preserve the interrupt for callers
			throw new IllegalStateException("Interrupted while waiting for a merging task", e);
		}
	}

	/** Adds every entry of {@code part} onto the entry of {@code total} at the same index. */
	private static void addInto(double[][][] total, double[][][] part) {
		for (int a=0; a<total.length; a++){
			for (int b=0; b<total[a].length; b++){
				for (int c=0; c<total[a][b].length; c++){
					total[a][b][c] += part[a][b][c];
				}
			}
		}
	}

}
